{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_tabnet.tab_model import TabNetClassifier\n",
    "\n",
    "import torch\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "np.random.seed(0)\n",
    "\n",
    "import scipy\n",
    "\n",
    "import os\n",
    "import wget\n",
    "from pathlib import Path\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Expose only the GPU with index 1 to this process.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = \"1\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "torch.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Download census-income dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Remote source of the raw file and the local path it is cached under (./data/<dataset_name>.csv).\n",
    "dataset_name = 'census-income'\n",
    "url = \"https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data\"\n",
    "out = Path(f\"{os.getcwd()}/data/{dataset_name}.csv\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the raw file only when no local copy exists yet.\n",
    "out.parent.mkdir(parents=True, exist_ok=True)\n",
    "if not out.exists():\n",
    "    print(\"Downloading file...\")\n",
    "    wget.download(url, out.as_posix())\n",
    "else:\n",
    "    print(\"File already exists.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load data and split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): the file is read with pandas' default header handling, so its first line\n",
    "# becomes the column names (hence the target column name below) - confirm this is intended.\n",
    "train = pd.read_csv(out)\n",
    "target = ' <=50K'\n",
    "\n",
    "# Draw a random 80/10/10 train/valid/test assignment unless the file already carries one.\n",
    "if \"Set\" not in train.columns:\n",
    "    train[\"Set\"] = np.random.choice([\"train\", \"valid\", \"test\"], p=[.8, .1, .1], size=len(train))\n",
    "\n",
    "# Row labels of each split.\n",
    "train_indices, valid_indices, test_indices = (\n",
    "    train.index[train[\"Set\"] == split_name] for split_name in (\"train\", \"valid\", \"test\")\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple preprocessing\n",
    "\n",
    "Label encode categorical features and fill empty cells."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "nunique = train.nunique()\n",
    "types = train.dtypes\n",
    "\n",
    "# categorical_columns: names of the label-encoded columns, in column order.\n",
    "# categorical_dims: number of distinct encoded classes per categorical column.\n",
    "categorical_columns = []\n",
    "categorical_dims = {}\n",
    "for col in train.columns:\n",
    "    # Text columns and columns with fewer than 200 distinct values are treated as categorical.\n",
    "    if types[col] == 'object' or nunique[col] < 200:\n",
    "        print(col, train[col].nunique())\n",
    "        l_enc = LabelEncoder()\n",
    "        # Missing values get their own category before encoding.\n",
    "        train[col] = train[col].fillna(\"VV_likely\")\n",
    "        train[col] = l_enc.fit_transform(train[col].values)\n",
    "        categorical_columns.append(col)\n",
    "        categorical_dims[col] = len(l_enc.classes_)\n",
    "    else:\n",
    "        # Impute this column only, with the mean of its training rows. A frame-wide fillna here\n",
    "        # would also overwrite missing values in every other column with this column's mean.\n",
    "        train[col] = train[col].fillna(train.loc[train_indices, col].mean())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# check that pipeline accepts strings\n",
    "# The target was label-encoded above; LabelEncoder sorts its classes, so (assuming the raw\n",
    "# labels are ' <=50K' and ' >50K') code 0 is the lower income bracket and code 1 the higher one.\n",
    "income_names = {0: \"not_wealthy\", 1: \"wealthy\"}\n",
    "\n",
    "# Build the string column in one pass instead of writing strings into an integer column\n",
    "# (a dtype-incompatible assignment). Values that are not 0/1 are kept as they are, so\n",
    "# re-running this cell leaves the already converted column unchanged.\n",
    "train[target] = train[target].map(lambda code: income_names.get(code, code))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define categorical features for categorical embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Columns that are neither model inputs nor the target.\n",
    "unused_feat = ['Set']\n",
    "\n",
    "features = [col for col in train.columns if col != target and col not in unused_feat]\n",
    "\n",
    "# Position (within `features`) and number of classes of every categorical feature, in feature order.\n",
    "cat_idxs = []\n",
    "cat_dims = []\n",
    "for position, name in enumerate(features):\n",
    "    if name in categorical_columns:\n",
    "        cat_idxs.append(position)\n",
    "        cat_dims.append(categorical_dims[name])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Grouped features\n",
    "\n",
    "You can now specify groups of features which will share a common attention.\n",
    "\n",
    "This may be very useful for features coming from the same preprocessing technique, like PCA for example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "len(features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "grouped_features = [[0, 1, 2], [8, 9, 10]]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Network parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Keyword arguments forwarded unchanged to the TabNetClassifier constructor.\n",
    "tabnet_params = dict(\n",
    "    cat_idxs=cat_idxs,\n",
    "    cat_dims=cat_dims,\n",
    "    cat_emb_dim=2,\n",
    "    optimizer_fn=torch.optim.Adam,\n",
    "    optimizer_params={\"lr\": 2e-2},\n",
    "    scheduler_fn=torch.optim.lr_scheduler.StepLR,\n",
    "    scheduler_params={\"step_size\": 50, \"gamma\": 0.9},  # how to use learning rate scheduler\n",
    "    mask_type='entmax',  # \"sparsemax\"\n",
    "    grouped_features=grouped_features,\n",
    ")\n",
    "\n",
    "clf = TabNetClassifier(**tabnet_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Materialise the feature matrix and target vector once, then slice every split out of them.\n",
    "# The split indices are used positionally here, which relies on `train` keeping its default\n",
    "# 0..n-1 row index.\n",
    "X_all = train[features].values\n",
    "y_all = train[target].values\n",
    "\n",
    "X_train, y_train = X_all[train_indices], y_all[train_indices]\n",
    "X_valid, y_valid = X_all[valid_indices], y_all[valid_indices]\n",
    "X_test, y_test = X_all[test_indices], y_all[test_indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train for 2 epochs when the CI environment variable is set to a non-empty value, 50 otherwise.\n",
    "max_epochs = 2 if os.getenv(\"CI\") else 50"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_tabnet.augmentations import ClassificationSMOTE\n",
    "aug = ClassificationSMOTE(p=0.2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This illustrates the behaviour of the model's fit method using Compressed Sparse Row matrices\n",
    "sparse_X_train, sparse_X_valid = (\n",
    "    scipy.sparse.csr_matrix(dense_split) for dense_split in (X_train, X_valid)\n",
    ")\n",
    "\n",
    "# Training split first, validation split second: the order matches eval_name below.\n",
    "evaluation_sets = [(sparse_X_train, y_train), (sparse_X_valid, y_valid)]\n",
    "\n",
    "# Fitting the model\n",
    "clf.fit(\n",
    "    X_train=sparse_X_train,\n",
    "    y_train=y_train,\n",
    "    eval_set=evaluation_sets,\n",
    "    eval_name=['train', 'valid'],\n",
    "    eval_metric=['auc'],\n",
    "    max_epochs=max_epochs,\n",
    "    patience=20,\n",
    "    batch_size=1024,\n",
    "    virtual_batch_size=128,\n",
    "    num_workers=0,\n",
    "    weights=1,\n",
    "    drop_last=False,\n",
    "    augmentations=aug,  # aug, None\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# This illustrates the warm_start=False behaviour\n",
    "\n",
    "\n",
    "def fit_twice_valid_auc(compute_importance):\n",
    "    \"\"\"Fit `clf` twice with identical settings and return both validation-AUC histories.\n",
    "\n",
    "    `compute_importance` is forwarded to `clf.fit`. The result is a two-element list\n",
    "    holding one copy of `clf.history[\"valid_auc\"]` per run.\n",
    "    \"\"\"\n",
    "    histories = []\n",
    "    for _ in range(2):\n",
    "        clf.fit(\n",
    "            X_train=X_train, y_train=y_train,\n",
    "            eval_set=[(X_train, y_train), (X_valid, y_valid)],\n",
    "            eval_name=['train', 'valid'],\n",
    "            eval_metric=['auc'],\n",
    "            max_epochs=max_epochs, patience=20,\n",
    "            batch_size=1024, virtual_batch_size=128,\n",
    "            num_workers=0,\n",
    "            weights=1,\n",
    "            drop_last=False,\n",
    "            augmentations=aug,  # aug, None\n",
    "            compute_importance=compute_importance,\n",
    "        )\n",
    "        # Copy the list so the stored history cannot alias the next run's history.\n",
    "        histories.append(list(clf.history[\"valid_auc\"]))\n",
    "    return histories\n",
    "\n",
    "\n",
    "# Fitting the model without starting from a warm start nor computing the feature importance\n",
    "save_history = fit_twice_valid_auc(compute_importance=False)\n",
    "# array_equal compares both histories element by element and is False when their lengths differ.\n",
    "assert np.array_equal(save_history[0], save_history[1]), \"validation AUC history differs between two identical fits\"\n",
    "\n",
    "# Fitting the model without starting from a warm start but with the computing of the feature importance activated\n",
    "# (compute_importance is True by default, so passing it is not strictly needed)\n",
    "save_history = fit_twice_valid_auc(compute_importance=True)\n",
    "assert np.array_equal(save_history[0], save_history[1]), \"validation AUC history differs between two identical fits\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot losses\n",
    "plt.plot(clf.history['loss'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot auc\n",
    "plt.plot(clf.history['train_auc'])\n",
    "plt.plot(clf.history['valid_auc'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot learning rates\n",
    "plt.plot(clf.history['lr'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Class probabilities for the held-out test split and for the validation split.\n",
    "preds = clf.predict_proba(X_test)\n",
    "preds_valid = clf.predict_proba(X_valid)\n",
    "\n",
    "# The probability in column 1 is used as the ranking score for the AUC.\n",
    "test_auc = roc_auc_score(y_true=y_test, y_score=preds[:, 1])\n",
    "valid_auc = roc_auc_score(y_true=y_valid, y_score=preds_valid[:, 1])\n",
    "\n",
    "print(f\"BEST VALID SCORE FOR {dataset_name} : {clf.best_cost}\")\n",
    "print(f\"FINAL TEST SCORE FOR {dataset_name} : {test_auc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# check that best weights are used\n",
    "assert np.isclose(valid_auc, np.max(clf.history['valid_auc']), atol=1e-6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf.predict(X_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Save and load Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# save tabnet model\n",
    "saving_path_name = \"./tabnet_model_test_1\"\n",
    "saved_filepath = clf.save_model(saving_path_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define new model with basic parameters and load state dict weights\n",
    "loaded_clf = TabNetClassifier()\n",
    "loaded_clf.load_model(saved_filepath)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_preds = loaded_clf.predict_proba(X_test)\n",
    "loaded_test_auc = roc_auc_score(y_score=loaded_preds[:,1], y_true=y_test)\n",
    "\n",
    "print(f\"FINAL TEST SCORE FOR {dataset_name} : {loaded_test_auc}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert(test_auc == loaded_test_auc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "loaded_clf.predict(X_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Global explainability: feature importances summing to 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf.feature_importances_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Local explainability and masks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "explain_matrix, masks = clf.explain(X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One panel per mask, showing the first 50 explained rows.\n",
    "fig, axs = plt.subplots(1, 3, figsize=(20,20))\n",
    "\n",
    "# assumes each mask has one column per entry of `features` - TODO confirm\n",
    "tick_positions = list(range(len(features)))\n",
    "\n",
    "for i in range(3):\n",
    "    axs[i].imshow(masks[i][:50])\n",
    "    axs[i].set_title(f\"mask {i}\")\n",
    "    # Fix one tick per feature column before labelling; without explicit tick positions the\n",
    "    # labels are attached to whatever ticks matplotlib picked automatically.\n",
    "    axs[i].set_xticks(tick_positions)\n",
    "    axs[i].set_xticklabels(labels=features, rotation=45)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# XGB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from xgboost import XGBClassifier\n",
    "\n",
    "# Constructor arguments of the gradient-boosted baseline, forwarded unchanged.\n",
    "xgb_params = dict(\n",
    "    max_depth=8,\n",
    "    learning_rate=0.1,\n",
    "    n_estimators=1000,\n",
    "    verbosity=0,\n",
    "    silent=None,\n",
    "    objective='binary:logistic',\n",
    "    booster='gbtree',\n",
    "    n_jobs=-1,\n",
    "    nthread=None,\n",
    "    gamma=0,\n",
    "    min_child_weight=1,\n",
    "    max_delta_step=0,\n",
    "    subsample=0.7,\n",
    "    colsample_bytree=1,\n",
    "    colsample_bylevel=1,\n",
    "    colsample_bynode=1,\n",
    "    reg_alpha=0,\n",
    "    reg_lambda=1,\n",
    "    scale_pos_weight=1,\n",
    "    base_score=0.5,\n",
    "    random_state=0,\n",
    "    seed=None,\n",
    ")\n",
    "clf_xgb = XGBClassifier(**xgb_params)\n",
    "\n",
    "# Early stopping is monitored on the validation split; progress is printed every 10 rounds.\n",
    "clf_xgb.fit(\n",
    "    X_train,\n",
    "    y_train,\n",
    "    eval_set=[(X_valid, y_valid)],\n",
    "    early_stopping_rounds=40,\n",
    "    verbose=10,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The XGBoost results get their own names so the TabNet values computed earlier in the notebook\n",
    "# (preds, valid_auc, test_auc) are not overwritten and the cells asserting on them stay re-runnable.\n",
    "xgb_preds_valid = np.array(clf_xgb.predict_proba(X_valid))\n",
    "xgb_valid_auc = roc_auc_score(y_score=xgb_preds_valid[:,1], y_true=y_valid)\n",
    "print(xgb_valid_auc)\n",
    "\n",
    "xgb_preds_test = np.array(clf_xgb.predict_proba(X_test))\n",
    "xgb_test_auc = roc_auc_score(y_score=xgb_preds_test[:,1], y_true=y_test)\n",
    "print(xgb_test_auc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
